#!/usr/bin/env python
import pointcloud
import freenect
import cv
import frame_convert as fc
import numpy as np
import scipy
import time
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import scipy.ndimage as ndi
import scipy.io
import sys
import combiner2 as combiner

# Frame size in pixels (rows x columns) used to allocate the capture buffers.
rownum = 480; colnum = 640
# Number of frames captureimage() grabs and averages; showlive() updates it
# when a digit key is pressed.
frames = 1
# Gaussian smoothing width. NOTE(review): not referenced by any live code in
# this file section -- possibly left over from an experiment; confirm before removing.
sigma = 3

def testcombine(mean = np.array([-0.02, 0, 0.8]), angle = 45.):
  """Incrementally register the 20 frames of 'data/case' and save one PLY.

  Frame k is loaded from 'data/case/f<k>.mat', shifted by `mean` and
  clipped with clipCube. From the second frame on:
    1. The transform accumulated so far (Rc, tc) is applied as the initial guess.
    2. 2000 points (drawn at random, with replacement) from the previous
       registered cloud and from this cloud go to combiner.probcombine.
    3. The returned (R, t) refines both the cloud and the accumulated transform.
  All registered clouds are concatenated and written to
  'poly/poly<timestamp>.ply'.

  `angle` is accepted for call compatibility but is not used.
  The registered clouds are left in the module-level `pointclouds`.
  """
  global pointclouds
  confidence = 1    # confidence in the initial guess, passed to probcombine
  numsample = 2000  # points drawn from each cloud per registration step

  pointclouds = []
  Rc = np.diag([1., 1., 1.])
  tc = np.array([0., 0., 0.])[None, :]
  prevcloud = None

  for i in range(20):
    print('iteration %d' % i)
    mat = scipy.io.loadmat('data/case/f%d.mat' % (i + 1))

    newcloud = pointcloud.pointcloud(mat['depth'], mat['rgb'])
    newcloud.vertex -= mean
    newcloud.clipCube(bottom = 0.0, left = -.3, right = .3)

    if prevcloud is not None:
      # Start from the transform accumulated over the previous frames.
      newcloud.vertex = np.dot(newcloud.vertex, Rc) + tc
      r1 = (np.random.rand(numsample) * prevcloud.vertex.shape[0]).astype(int)
      r2 = (np.random.rand(numsample) * newcloud.vertex.shape[0]).astype(int)

      R, t = combiner.probcombine(prevcloud.vertex[r1, :].T,
                                  newcloud.vertex[r2, :].T, conf = confidence)

      # Fold the refinement into the accumulated transform so that
      # v * Rc + tc maps a raw (mean-shifted) frame straight to the result.
      Rc = np.dot(Rc, R.T)
      tc = np.dot(tc, R.T) + t.T
      newcloud.vertex = np.dot(newcloud.vertex, R.T) + t.T

    pointclouds.append(newcloud)
    prevcloud = newcloud

  totalcloud = pointcloud.pointcloud()
  for cloud in pointclouds:
    totalcloud += cloud

  # `with` guarantees the handle is closed even if toPly raises.
  with open('poly/poly%d.ply' % time.time(), 'w') as f:
    totalcloud.toPly(f)


def combineseries(clouds = 'case', trans = 'bear1301126296.mat',
                  start = 0, numframe = None, mean = np.array([-0.02, 0, 0.8]),
                  connectends = 0):
  """Merge a series of frames using stored pairwise transforms; write a PLY.

  Loads 'trans/<trans>' (keys 'Rs' and 'ts'; presumably written by
  gettransform()). It chains the pairwise (R, t) into one cumulative pose per
  frame, stored in `Rss` and `tss`. Pose k is applied to
  'data/<clouds>/f<k+1>.mat' after shifting that cloud by `mean` and
  clipping it. The merged cloud is written to 'poly/poly<timestamp>.ply'.

  clouds      -- dataset directory under 'data/'.
  trans       -- transform file name under 'trans/'.
  start       -- first frame to add; frames visited are (start + j) % numframe.
  numframe    -- number of frames to merge; defaults to the number of stored
                 transforms.
  connectends -- if truthy:
                 * rotations are adjusted by combiner.makeprodidentity;
                 * the translation gap between frame 0 and frame numframe-1
                   is spread linearly over the frames.

  The cumulative translations are left in the module-level `tss`.
  """
  global tss

  mats = scipy.io.loadmat('trans/%s' % trans)
  Rs = mats['Rs']; ts = mats['ts']

  if connectends:
    # NOTE(review): only the adjusted rotations are kept; the translations
    # returned by makeprodidentity are discarded and `ts` is used as loaded.
    Rs = combiner.makeprodidentity(Rs, ts, connectends)[0]

  # Previously this default applied only when start == 0, so any other start
  # with numframe=None crashed on range(0, None).
  if numframe is None:
    numframe = Rs.shape[0]

  nposes = Rs.shape[0]
  tss = np.zeros((nposes, 1, 3))
  Rss = np.zeros((nposes, 3, 3))
  Rc = np.diag([1., 1., 1.])
  tc = np.array([0., 0., 0.])[None, :]

  for i in range(nposes):
    print('iteration %d' % i)
    tss[i] = tc
    Rss[i] = Rc
    tc = np.dot(ts[i].T, Rc) + tc
    Rc = np.dot(Rs[i].T, Rc)

  if connectends:
    diff = tss[0] - tss[numframe-1]
    for i in range(numframe):
      tss[i] = tss[i] + (i+1)*diff/(numframe)

  totalcloud = pointcloud.pointcloud()
  for j in range(numframe):
    # The frame file and its pose must share one wrapped index; the old code
    # wrapped only the file index and indexed the poses past their end.
    k = (j + start) % numframe
    mat = scipy.io.loadmat('data/%s/f%d.mat' % (clouds, k + 1))
    newcloud = pointcloud.pointcloud(mat['depth'], mat['rgb'])
    newcloud.vertex -= mean
    newcloud.clipCube(bottom = 0.05, left = -.3, right = .3)
    newcloud.vertex = np.dot(newcloud.vertex, Rss[k]) + tss[k]
    totalcloud += newcloud

  with open('poly/poly%d.ply' % time.time(), 'w') as f:
    totalcloud.toPly(f)


def gettransform(mean = np.array([-0.02, 0, 0.8]), angle = 45., dataset = 'bear',
                 numframe = 18, numsample = 5000, confidence = 5):
  """Estimate pairwise transforms around a closed frame sequence and save them.

  Frames 'data/<dataset>/f1.mat' .. f<numframe>.mat are each shifted by `mean`
  and clipped. For every consecutive pair, `numsample` random points
  (with replacement) from each cloud go to combiner.probcombine.
  The loop runs numframe + 1 times, and the last pass reloads frame 1.
  As a result, Rs[numframe-1] / ts[numframe-1] link the last frame back to
  the first.

  The arrays Rs (numframe, 3, 3) and ts (numframe, 3, 1) are saved to
  'trans/<dataset><timestamp>.mat'.

  numframe   -- number of frames in the dataset (was hard-coded to 18).
  numsample  -- points sampled from each cloud per pair (was 5000).
  confidence -- confidence in the initial guess, passed to probcombine (was 5).

  `angle` is accepted for call compatibility but is not used.
  The loaded clouds are left in the module-level `pointclouds`.
  """
  global pointclouds
  pointclouds = []

  Rs = np.zeros((numframe, 3, 3))
  ts = np.zeros((numframe, 3, 1))
  prevcloud = None

  for i in range(numframe + 1):
    print('iteration %d' % i)
    # The modulo makes the final pass (i == numframe) reload frame 1.
    mat = scipy.io.loadmat('data/%s/f%d.mat' % (dataset, i % numframe + 1))
    newcloud = pointcloud.pointcloud(mat['depth'], mat['rgb'])
    newcloud.vertex -= mean
    newcloud.clipCube(bottom = 0.06, left = -.3, right = .3)

    if prevcloud is not None:
      r1 = (np.random.rand(numsample) * prevcloud.vertex.shape[0]).astype(int)
      r2 = (np.random.rand(numsample) * newcloud.vertex.shape[0]).astype(int)

      R, t = combiner.probcombine(prevcloud.vertex[r1, :].T,
                                  newcloud.vertex[r2, :].T, conf = confidence)
      Rs[i-1] = R
      ts[i-1] = t

    pointclouds.append(newcloud)
    prevcloud = newcloud

  scipy.io.savemat('trans/%s%d.mat' % (dataset, time.time()), {'Rs':Rs, 'ts':ts})
  

def combine(mean = np.array([-0.02, 0.8]), angle = 45.,
            dataset = 'soysauce', numframe = 8):
  """Merge evenly rotated views into one PLY using a fixed rotation step.

  View i (file 'data/<dataset>/f<i+1>.mat') is clipped with
  clipCube(bottom=-0.01). Its x/z coordinates (columns 0 and 2) are then
  shifted by `mean` and rotated by i * angle degrees in that plane. The
  merged cloud is clipped once more and written to 'poly/poly<timestamp>.ply'.

  dataset  -- dataset directory under 'data/' (was hard-coded 'soysauce').
  numframe -- number of views to merge (was hard-coded 8).
  """
  totalcloud = pointcloud.pointcloud()

  xz = np.array([0, 2])
  rot = -angle*np.pi/180  # per-view rotation step in radians (loop-invariant)

  for i in range(numframe):
    c, s = np.cos(i*rot), np.sin(i*rot)
    rotmat = np.array([[c, -s],
                       [s,  c]])
    mat = scipy.io.loadmat('data/%s/f%d.mat' % (dataset, i + 1))
    newcloud = pointcloud.pointcloud(mat['depth'], mat['rgb'])
    newcloud.clipCube(bottom = -0.01)
    newcloud.vertex[:, xz] = np.dot(newcloud.vertex[:, xz] - mean, rotmat)
    totalcloud += newcloud

  totalcloud.clipCube(bottom = -0.01)
  # `with` guarantees the handle is closed even if toPly raises.
  with open('poly/poly%d.ply' % time.time(), 'w') as f:
    totalcloud.toPly(f)
  

def captureimage():
  """Grab `frames` depth/RGB frames, combine them, and save the result.

  Reads the module-level `frames`, `rownum` and `colnum`. The per-pixel
  stacks are combined with fc.robustavg. Outputs are stamped with the
  current time:
    * 'img/depth<serial>.png' and 'img/video<serial>.png' -- previews;
    * 'data/aligned<serial>.mat' -- keys 'depth' (output of fc.meter_depth,
      presumably meters) and 'rgb' (fc.matchrgb applied to that depth
      and the averaged RGB).
  Depth values above 1.5 or below 0.5 are set to -1.
  """
  depthframes = np.zeros((frames, rownum, colnum))
  rgbframes = np.zeros((frames, rownum, colnum, 3))

  for i in range(frames):
    # One depth grab and one video grab per frame. The old code made an
    # extra, discarded video grab here, which only slowed capture down.
    depthframes[i] = freenect.sync_get_depth()[0]
    rgbframes[i] = freenect.sync_get_video()[0]
    time.sleep(0.05)

  arargb = fc.robustavg(rgbframes)
  aradepth = fc.robustavg(depthframes)
  serial = time.time()

  cv.SaveImage('img/depth%d.png' % serial, fc.depth_cv(aradepth.astype(int)))
  cv.SaveImage('img/video%d.png' % serial, fc.video_cv(arargb.astype(np.uint8)))

  meterdepth = fc.meter_depth(aradepth)
  newrgb = fc.matchrgb(meterdepth, arargb)

  # Flag depths outside [0.5, 1.5] as invalid.
  meterdepth[(meterdepth > 1.5) | (meterdepth < 0.5)] = -1.
  scipy.io.savemat('data/aligned%d.mat' % serial, {'depth':meterdepth, 'rgb':newrgb})


count = 0  # number of display-loop iterations run by showlive()


def showlive():
  """Show live depth and video windows and react to key presses.

  Keys: ESC or 'k' leaves the loop and closes the windows; space runs
  captureimage(); a digit 0-9 sets the module-level `frames` used by
  captureimage(). The module-level `count` is incremented once per loop
  iteration that does not exit.
  """
  global count, frames
  cv.NamedWindow('Depth')
  cv.NamedWindow('Video')
  cv.MoveWindow('Depth', 100, 100)
  cv.MoveWindow('Video', 745, 100)

  print('Press ESC in window to stop')
  print('Press Space to convert current to PLY')
  print('Press k to stop live capture')

  while True:
    cv.ShowImage('Depth', fc.depth_cv(freenect.sync_get_depth()[0]))
    cv.ShowImage('Video', fc.video_cv(freenect.sync_get_video()[0]))

    key = cv.WaitKey(100)
    if key != -1:
      # Keep the low 20 bits (presumably to drop modifier flags WaitKey may
      # report). Codes still above 255 (arrows, function keys) are not
      # characters: chr() would raise on them, so ignore them.
      key %= 1048576
      if key < 256:
        ch = chr(key)
        if ch == ' ':  # space: capture and convert
          print('capturing images')
          captureimage()
          print('done capturing')
        elif '0' <= ch <= '9':  # ASCII digits only (isdigit() accepts others)
          frames = ord(ch) - ord('0')
          print('setting the number of frames to capture to %d' % frames)
        elif ch in ('k', '\x1b'):  # 'k' or ESC, as the prompt promises
          break
    count += 1

  cv.DestroyWindow('Depth')
  cv.DestroyWindow('Video')


if __name__ == '__main__':
  # Command-line entry: exactly one argument selects the action
  # ('c' capture, 't' test registration, 'b' fixed-angle combine);
  # any other single argument opens the live viewer.
  argc = len(sys.argv)
  if argc == 2:
    {'c': captureimage,
     't': testcombine,
     'b': combine}.get(sys.argv[1], showlive)()
